// Copyright (C) Kumo inc. and its affiliates.
// Author: Jeff.li lijippy@163.com
// All rights reserved.
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//


#include <boost/random/uniform_int_distribution.hpp>
#include <gtest/gtest.h>

#include <string>

#include <pollux/common/file/file_systems.h>

#include <pollux/testing/common/faulty_file_system.h>
#include <pollux/common/hyperloglog/sparse_hll.h>
#include <pollux/common/testutil/test_value.h>
#include <pollux/dwio/dwrf/writer/writer.h>
#include <pollux/exec/operator_trace_reader.h>
#include <pollux/exec/partition_function.h>
#include <pollux/exec/plan_node_stats.h>
#include <pollux/exec/table_writer.h>
#include <pollux/exec/trace_util.h>
#include <pollux/testing/exec/util/arbitrator_test_util.h>
#include <pollux/testing/exec/util/assert_query_builder.h>
#include <pollux/testing/exec/util/hive_connector_test_base.h>
#include <pollux/plan/plan_builder.h>
#include <pollux/common/file/temp_directory_path.h>
#include <pollux/serializers/presto_serializer.h>
#include <pollux/tool/trace/hash_join_replayer.h>
#include <pollux/tool/trace/trace_file_tool_runner.h>
#include <pollux/tool/trace/trace_replay_runner.h>

using namespace kumo::pollux;
using namespace kumo::pollux::core;
using namespace kumo::pollux::common;
using namespace kumo::pollux::exec;
using namespace kumo::pollux::exec::test;
using namespace kumo::pollux::connector;
using namespace kumo::pollux::connector::hive;
using namespace kumo::pollux::dwio::common;
using namespace kumo::pollux::common::testutil;
using namespace kumo::pollux::common::hll;
using namespace kumo::pollux::tests::utils;

namespace kumo::pollux::tool::trace::test {
/// Fixture for the trace file tool tests. It builds a two-table hash join
/// plan over Hive splits written into a temporary directory, and remembers the
/// id of the join node so tests can trace and replay it.
class TraceFileToolTest : public HiveConnectorTestBase {
 protected:
  /// One-time suite setup: resets the memory manager, runs the base suite
  /// setup, registers the faulty file system, makes sure a vector serde is
  /// registered, and registers the SerDe hooks for types, filters, Hive
  /// handles/splits, plan nodes, expressions and partition functions.
  static void SetUpTestCase() {
    memory::MemoryManager::testingSetInstance({});
    HiveConnectorTestBase::SetUpTestCase();
    registerFaultyFileSystem();
    // Only install the Presto serde when no serde has been registered yet.
    if (!isRegisteredVectorSerde()) {
      serializer::presto::PrestoVectorSerde::registerVectorSerde();
    }
    Type::registerSerDe();
    common::Filter::registerSerDe();
    connector::hive::HiveTableHandle::registerSerDe();
    connector::hive::LocationHandle::registerSerDe();
    connector::hive::HiveColumnHandle::registerSerDe();
    connector::hive::HiveInsertTableHandle::registerSerDe();
    connector::hive::HiveConnectorSplit::registerSerDe();
    core::PlanNode::registerSerDe();
    core::ITypedExpr::registerSerDe();
    registerPartitionFunctionSerDe();
  }

  /// Drops the cached input vectors and then delegates to the base fixture.
  void TearDown() override {
    // NOTE(review): presumably the vectors must be released before the base
    // class tears down the pools that own their memory - confirm.
    probeInput_.clear();
    buildInput_.clear();
    HiveConnectorTestBase::TearDown();
  }

  /// A plan together with the ids of its two table scans and the splits to
  /// feed to each scan.
  struct PlanWithSplits {
    core::PlanNodePtr plan;
    core::PlanNodeId probeScanId;
    core::PlanNodeId buildScanId;
    melon::F14FastMap<core::PlanNodeId, std::vector<exec::Split>> splits;

    explicit PlanWithSplits(
        const core::PlanNodePtr& _plan,
        const core::PlanNodeId& _probeScanId = "",
        const core::PlanNodeId& _buildScanId = "",
        const melon::F14FastMap<
            core::PlanNodeId,
            std::vector<pollux::exec::Split>>& _splits = {})
        : plan(_plan),
          probeScanId(_probeScanId),
          buildScanId(_buildScanId),
          splits(_splits) {}
  };

  /// Returns a row type holding all columns of 'a' followed by all columns of
  /// 'b', in their original order.
  RowTypePtr concat(const RowTypePtr& a, const RowTypePtr& b) {
    std::vector<std::string> names = a->names();
    std::vector<TypePtr> types = a->children();

    // Bulk-append instead of an index loop; this also avoids comparing a
    // signed loop counter against the row type's size.
    const auto& extraNames = b->names();
    const auto& extraTypes = b->children();
    names.insert(names.end(), extraNames.begin(), extraNames.end());
    types.insert(types.end(), extraTypes.begin(), extraTypes.end());

    return ROW(std::move(names), std::move(types));
  }

  /// Convenience wrapper around HiveConnectorTestBase::makeVectors with the
  /// argument order (count, rowsPerVector, rowType).
  std::vector<RowVectorPtr>
  makeVectors(int32_t count, int32_t rowsPerVector, const RowTypePtr& rowType) {
    return HiveConnectorTestBase::makeVectors(rowType, count, rowsPerVector);
  }

  /// Writes 'inputs' to four files named "<path>/0" .. "<path>/3" and returns
  /// one Hive connector split per file.
  ///
  /// @param inputs Vectors written to every file.
  /// @param path Directory prefix for the files.
  /// @param writerPool Unused; kept so existing call sites stay valid.
  std::vector<Split> makeSplits(
      const std::vector<RowVectorPtr>& inputs,
      const std::string& path,
      memory::MemoryPool* writerPool) {
    static constexpr int32_t kNumSplitFiles = 4;
    std::vector<Split> splits;
    splits.reserve(kNumSplitFiles);
    for (int32_t i = 0; i < kNumSplitFiles; ++i) {
      const std::string filePath = fmt::format("{}/{}", path, i);
      writeToFile(filePath, inputs);
      splits.emplace_back(makeHiveConnectorSplit(filePath));
    }

    return splits;
  }

  /// Builds a hash join of a probe-side table scan with a build-side table
  /// scan and writes the split files backing both scans.
  ///
  /// The scan schemas and the join's output columns are taken from the first
  /// vector of 'probeInput' and 'buildInput', so the plan always matches the
  /// data written for it. Both inputs must be non-empty. The id of the join
  /// node is stored in 'traceNodeId_'.
  ///
  /// @param tableDir Directory under which "probe" and "build" files go.
  /// @param joinType Join type of the hash join node.
  /// @param probeKeys Join key column names on the probe side.
  /// @param buildKeys Join key column names on the build side.
  /// @param probeInput Data written to the probe-side files.
  /// @param buildInput Data written to the build-side files.
  /// @return The plan, both scan node ids, and the splits keyed by scan id.
  PlanWithSplits createPlan(
      const std::string& tableDir,
      core::JoinType joinType,
      const std::vector<std::string>& probeKeys,
      const std::vector<std::string>& buildKeys,
      const std::vector<RowVectorPtr>& probeInput,
      const std::vector<RowVectorPtr>& buildInput) {
    auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
    const std::vector<Split> probeSplits =
        makeSplits(probeInput, fmt::format("{}/probe", tableDir), pool());
    const std::vector<Split> buildSplits =
        makeSplits(buildInput, fmt::format("{}/build", tableDir), pool());

    // Use the schemas of the data actually passed in rather than the fixture
    // members, so callers supplying their own inputs get a consistent plan.
    const auto probeType = as_row_type(probeInput[0]->type());
    const auto buildType = as_row_type(buildInput[0]->type());
    const auto outputColumns = concat(probeType, buildType)->names();

    core::PlanNodeId probeScanId;
    core::PlanNodeId buildScanId;
    auto plan = PlanBuilder(planNodeIdGenerator)
                    .tableScan(probeType)
                    .capturePlanNodeId(probeScanId)
                    .hashJoin(
                        probeKeys,
                        buildKeys,
                        PlanBuilder(planNodeIdGenerator)
                            .tableScan(buildType)
                            .capturePlanNodeId(buildScanId)
                            .planNode(),
                        /*filter=*/"",
                        outputColumns,
                        joinType,
                        false)
                    .capturePlanNodeId(traceNodeId_)
                    .planNode();
    return PlanWithSplits{
        plan,
        probeScanId,
        buildScanId,
        {{probeScanId, probeSplits}, {buildScanId, buildSplits}}};
  }

  // Id of the hash join node produced by the most recent createPlan() call.
  core::PlanNodeId traceNodeId_;
  RowTypePtr probeType_{
      ROW({"t0", "t1", "t2", "t3"}, {BIGINT(), VARCHAR(), SMALLINT(), REAL()})};

  RowTypePtr buildType_{
      ROW({"u0", "u1", "u2", "u3"},
          {BIGINT(), INTEGER(), SMALLINT(), VARCHAR()})};
  // Declared after the row types because their initializers read them.
  std::vector<RowVectorPtr> probeInput_ = makeVectors(5, 100, probeType_);
  std::vector<RowVectorPtr> buildInput_ = makeVectors(3, 100, buildType_);

  const std::vector<std::string> probeKeys_{"t0"};
  const std::vector<std::string> buildKeys_{"u0"};
  const std::shared_ptr<TempDirectoryPath> testDir_ =
      TempDirectoryPath::create();
  // Declared after 'testDir_' because its initializer reads it.
  const std::string tableDir_ =
      fmt::format("{}/{}", testDir_->getPath(), "table");
};

// Runs a traced hash join query, copies its trace files with the trace file
// tool under several query-id/task-id filters, and checks that replaying the
// join from the copied files reproduces the traced query's result.
TEST_F(TraceFileToolTest, basic) {
  const auto traceRoot =
      fmt::format("{}/{}/traceRoot", testDir_->getPath(), "basic");
  // Build the plan once: every createPlan() call rewrites all split files and
  // resets 'traceNodeId_', so an extra unused plan only costs redundant I/O.
  const auto tracePlanWithSplits = createPlan(
      tableDir_,
      core::JoinType::kInner,
      probeKeys_,
      buildKeys_,
      probeInput_,
      buildInput_);
  AssertQueryBuilder traceBuilder(tracePlanWithSplits.plan);
  traceBuilder.maxDrivers(4)
      .config(core::QueryConfig::kQueryTraceEnabled, true)
      .config(core::QueryConfig::kQueryTraceDir, traceRoot)
      .config(core::QueryConfig::kQueryTraceMaxBytes, 100UL << 30)
      .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*")
      .config(core::QueryConfig::kQueryTraceNodeIds, traceNodeId_);

  for (const auto& [planNodeId, nodeSplits] : tracePlanWithSplits.splits) {
    traceBuilder.splits(planNodeId, nodeSplits);
  }
  std::shared_ptr<Task> task;
  const auto result = traceBuilder.copyResults(pool(), task);
  // Fail cleanly instead of dereferencing a null task below.
  ASSERT_TRUE(task != nullptr);
  const auto queryId = task->queryCtx()->queryId();
  const auto taskId = task->taskId();

  struct {
    std::string traceQueryId;
    std::string traceTaskId;

    std::string debugString() const {
      return fmt::format("queryId={}, taskId={}", traceQueryId, traceTaskId);
    }
  } testSettings[]{
      {"", ""},
      {queryId, ""},
      {queryId, taskId},
      // Task id without a query id: init() is expected to reject this.
      {"", taskId},
  };

  for (const auto& testData : testSettings) {
    SCOPED_TRACE(testData.debugString());
    // Each setting copies into its own fresh destination directory.
    const auto testDir = TempDirectoryPath::create();
    const auto destRoot = testDir->getPath();
    FLAGS_source_root_dir = traceRoot;
    FLAGS_dest_root_dir = destRoot;
    FLAGS_trace_query_id = testData.traceQueryId;
    FLAGS_trace_task_id = testData.traceTaskId;
    FLAGS_trace_file_op = "copy";
    TraceFileToolRunner runner;
    if (FLAGS_trace_query_id.empty() && !FLAGS_trace_task_id.empty()) {
      POLLUX_ASSERT_THROW(
          runner.init(),
          "Trace query ID is empty but trace task ID is not empty");
      continue;
    }
    runner.init();
    runner.run();

    // Replay the join from the copied location and compare against the result
    // of the traced run.
    const auto replayingResult = HashJoinReplayer(
                                     destRoot,
                                     queryId,
                                     taskId,
                                     traceNodeId_,
                                     "HashJoin",
                                     "",
                                     0,
                                     executor_.get())
                                     .run();
    assertEqualResults({result}, {replayingResult});
  }
}
} // namespace kumo::pollux::tool::trace::test
